use std::collections::HashMap;

use inkwell::builder::Builder;
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::types::BasicMetadataTypeEnum;
use inkwell::values::{FunctionValue, IntValue, PointerValue};

use crate::ast::*;
/// LLVM IR generator for parsed programs.
///
/// Bundles the LLVM context it creates everything from, the module that receives
/// the generated functions, and the single instruction builder used to emit them.
pub struct Codegen<'ctx> {
    /// Context used to create the module, the builder, types and basic blocks.
    pub context: &'ctx Context,
    /// Module that every compiled function is added to.
    pub module: Module<'ctx>,
    /// Instruction builder; its insertion point is moved between basic blocks while lowering.
    pub builder: Builder<'ctx>,
}
impl<'ctx> Codegen<'ctx> {
pub fn new(context: &'ctx Context, module_name: &str) -> Self {
let module = context.create_module(module_name);
let builder = context.create_builder();
Self { context, module, builder }
}
pub fn compile_program(&self, program: &Program) -> Result<(), String> {
for item in &program.items {
if let TopLevel::Func(func) = item {
self.compile_function(func)?;
}
}
Ok(())
}
fn compile_function(&self, func: &FuncDecl) -> Result<(), String> {
let i64_type = self.context.i64_type();
let param_types: Vec<BasicMetadataTypeEnum> = func
.params
.iter()
.map(|_| i64_type.into())
.collect();
let fn_type = i64_type.fn_type(¶m_types, false);
let function = self.module.add_function(&func.name, fn_type, None);
let entry = self.context.append_basic_block(function, "entry");
self.builder.position_at_end(entry);
use std::collections::HashMap;
let mut locals: HashMap<String, IntValue> = HashMap::new();
for (i, param) in func.params.iter().enumerate() {
let arg = function.get_nth_param(i as u32).unwrap().into_int_value();
let alloca = self.builder.build_alloca(i64_type, ¶m.name);
self.builder.build_store(alloca, arg);
locals.insert(param.name.clone(), self.builder.build_load(i64_type, alloca, ¶m.name).into_int_value());
}
self.lower_block(&func.body, &mut locals, function)?;
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
let ret_val = i64_type.const_int(0, false);
self.builder.build_return(Some(&ret_val));
}
Ok(())
}
fn lower_block(&self, block: &crate::ast::mod::Block, locals: &mut std::collections::HashMap<String, IntValue<'ctx>>, function: FunctionValue<'ctx>) -> Result<(), String> {
for stmt in &block.statements {
self.lower_stmt(stmt, locals, function)?;
}
Ok(())
}
fn lower_stmt(&self, stmt: &Stmt, locals: &mut std::collections::HashMap<String, IntValue<'ctx>>, function: FunctionValue<'ctx>) -> Result<(), String> {
let i64_type = self.context.i64_type();
match stmt {
Stmt::VarDecl(v) => {
let alloca = self.builder.build_alloca(i64_type, &v.name);
let init_val = self.lower_expr(&v.init, locals)?;
self.builder.build_store(alloca, init_val);
locals.insert(v.name.clone(), self.builder.build_load(i64_type, alloca, &v.name).into_int_value());
}
Stmt::Return(Some(expr)) => {
let ret_val = self.lower_expr(expr, locals)?;
self.builder.build_return(Some(&ret_val));
}
Stmt::Expr(expr) => {
self.lower_expr(expr, locals)?;
}
Stmt::If(if_stmt) => {
let cond_val = self.lower_expr(&if_stmt.condition, locals)?;
let zero = i64_type.const_int(0, false);
let cond_bool = self.builder.build_int_compare(inkwell::IntPredicate::NE, cond_val, zero, "ifcond");
let then_bb = self.context.append_basic_block(function, "then");
let else_bb = self.context.append_basic_block(function, "else");
let merge_bb = self.context.append_basic_block(function, "ifcont");
self.builder.build_conditional_branch(cond_bool, then_bb, else_bb);
self.builder.position_at_end(then_bb);
self.lower_block(&if_stmt.then_branch, locals, function)?;
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
self.builder.build_unconditional_branch(merge_bb);
}
self.builder.position_at_end(else_bb);
if let Some(else_block) = &if_stmt.else_branch {
self.lower_block(else_block, locals, function)?;
}
if self.builder.get_insert_block().unwrap().get_terminator().is_none() {
self.builder.build_unconditional_branch(merge_bb);
}
self.builder.position_at_end(merge_bb);
}
Stmt::While(while_stmt) => {
let cond_bb = self.context.append_basic_block(function, "whilecond");
let body_bb = self.context.append_basic_block(function, "whilebody");
let after_bb = self.context.append_basic_block(function, "whileafter");
self.builder.build_unconditional_branch(cond_bb);
self.builder.position_at_end(cond_bb);
let cond_val = self.lower_expr(&while_stmt.condition, locals)?;
let zero = i64_type.const_int(0, false);
let cond_bool = self.builder.build_int_compare(inkwell::IntPredicate::NE, cond_val, zero, "whilecond");
self.builder.build_conditional_branch(cond_bool, body_bb, after_bb);
self.builder.position_at_end(body_bb);
self.lower_block(&while_stmt.body, locals, function)?;
self.builder.build_unconditional_branch(cond_bb);
self.builder.position_at_end(after_bb);
}
_ => return Err("Unsupported statement in codegen".to_string()),
}
Ok(())
}
fn lower_expr(&self, expr: &Expr, locals: &std::collections::HashMap<String, IntValue<'ctx>>) -> Result<IntValue<'ctx>, String> {
match expr {
Expr::Literal(lit) => match lit {
Literal::Int(i) => Ok(self.context.i64_type().const_int(*i as u64, true)),
Literal::Bool(b) => Ok(self.context.i64_type().const_int(if *b { 1 } else { 0 }, false)),
_ => Err("Only integer/bool literals are supported in codegen".to_string()),
},
Expr::Ident(name) => {
locals.get(name).cloned().ok_or_else(|| format!("Undefined variable {}", name))
}
Expr::Binary(lhs, op, rhs) => {
let left = self.lower_expr(lhs, locals)?;
let right = self.lower_expr(rhs, locals)?;
match op {
BinOp::Add => Ok(self.builder.build_int_add(left, right, "addtmp")),
BinOp::Sub => Ok(self.builder.build_int_sub(left, right, "subtmp")),
BinOp::Mul => Ok(self.builder.build_int_mul(left, right, "multmp")),
BinOp::Eq => {
let cmp = self.builder.build_int_compare(inkwell::IntPredicate::EQ, left, right, "eqtmp");
Ok(self.builder.build_int_z_extend(cmp, self.context.i64_type(), "booltmp"))
},
BinOp::Lt => {
let cmp = self.builder.build_int_compare(inkwell::IntPredicate::SLT, left, right, "lttmp");
Ok(self.builder.build_int_z_extend(cmp, self.context.i64_type(), "booltmp"))
},
_ => Err("Unsupported binary operator in codegen".to_string()),
}
}
_ => Err("Unsupported expression in codegen".to_string()),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;
    use inkwell::targets::{InitializationConfig, Target};

    /// Parses `source` and lowers it into a fresh module named `module_name`.
    ///
    /// Returns the generator (so callers can inspect `codegen.module`) together with
    /// the result of `compile_program`.
    fn compile<'ctx>(
        context: &'ctx Context,
        module_name: &str,
        source: &str,
    ) -> (Codegen<'ctx>, Result<(), String>) {
        Target::initialize_all(&InitializationConfig::default());
        let codegen = Codegen::new(context, module_name);
        let mut parser = Parser::new(source);
        let program = parser.parse_program().expect("parse failed");
        let result = codegen.compile_program(&program);
        (codegen, result)
    }

    /// Asserts that `source` compiles, defines `func_name`, and produces IR that passes
    /// LLVM's module verifier.
    fn assert_compiles(module_name: &str, source: &str, func_name: &str) {
        let context = Context::create();
        let (codegen, result) = compile(&context, module_name, source);
        result.expect("codegen failed");
        assert!(codegen.module.get_function(func_name).is_some());
        if let Err(err) = codegen.module.verify() {
            panic!("invalid IR for `{}`: {}", func_name, err);
        }
    }

    #[test]
    fn test_codegen_simple() {
        assert_compiles(
            "test_module",
            "def add(a: int, b: int) -> int { return a + b; }",
            "add",
        );
    }

    #[test]
    fn test_var_decl() {
        assert_compiles(
            "test_module_var",
            "def foo() -> int { let x = 42; return x; }",
            "foo",
        );
    }

    #[test]
    fn test_if_stmt() {
        assert_compiles(
            "test_module_if",
            "def abs(x: int) -> int { if x < 0 { return 0 - x; } return x; }",
            "abs",
        );
    }

    #[test]
    fn test_while_stmt() {
        // Written as `0 < n`: the code generator only lowers `<` and `==` comparisons.
        assert_compiles(
            "test_module_while",
            "def loop(n: int) -> int { while 0 < n { let x = n; } return 0; }",
            "loop",
        );
    }

    #[test]
    fn test_return_inside_while() {
        // A body ending in `return` must not also get the back-edge branch.
        assert_compiles(
            "test_module_while_return",
            "def first(n: int) -> int { while 0 < n { return n; } return 0; }",
            "first",
        );
    }

    #[test]
    fn test_branch_local_is_scoped() {
        let context = Context::create();
        let source = "def f(c: int) -> int { if c < 1 { let y = 1; } return y; }";
        let (_codegen, result) = compile(&context, "test_module_scope", source);
        assert!(
            result.is_err(),
            "a binding declared inside `if` must not be visible after it"
        );
    }
}